import scrublet as scr
import scipy.io
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import argparse
# Command-line interface. Every option is required: with a None default the
# script would fail later with an obscure TypeError (read_csv / float), and
# DataFrame.to_csv(None) silently returns a string instead of writing a file.
parser = argparse.ArgumentParser(description='manual to this script')
parser.add_argument('--gene_cell_matrix_input_file', type=str, required=True,
                    help='Tab-separated gene x cell count matrix; the first '
                         'column is used as the row index.')
# Parsed as float here so a malformed value is reported by argparse with a
# usage message rather than as a ValueError traceback later on.
parser.add_argument('--expected_doublet_rate', type=float, required=True,
                    help='Expected doublet rate passed to scrublet.Scrublet.')
parser.add_argument('--predicted_output_file', type=str, required=True,
                    help='Path of the tab-separated prediction table to write.')
args = parser.parse_args()
# Read the matrix (first column becomes the row index) and transpose it so
# cells become rows. NOTE(review): assumes the file is genes x cells, as the
# argument name suggests; Scrublet expects a cells x genes matrix. Confirm.
counts_matrix = pd.read_csv(
    args.gene_cell_matrix_input_file, sep='\t', index_col=0
).T
# float() accepts either a str or an already-parsed float.
edr = float(args.expected_doublet_rate)
scrub = scr.Scrublet(counts_matrix, expected_doublet_rate=edr)
doublet_scores, predicted_doublets = scrub.scrub_doublets(
    min_counts=1,
    min_cells=1,
    min_gene_variability_pctl=85,
    n_prin_comps=30,
)
# Per-cell doublet scores, indexed by the transposed matrix's row labels
# (the cell barcodes). Uses the public `.index` instead of the private
# pandas attribute `_stat_axis`.
predict_table = pd.DataFrame(
    scrub.doublet_scores_obs_,
    index=counts_matrix.index.tolist(),
    columns=['doublet_scores'],
)


# Score threshold used when Scrublet could not choose one automatically.
FALLBACK_DOUBLET_THRESHOLD = 0.25

# predicted_doublets_ is None when scrub_doublets() did not produce doublet
# calls (presumably because automatic threshold detection failed). In that
# case, call doublets with the fixed fallback threshold; call_doublets()
# populates scrub.predicted_doublets_.
if scrub.predicted_doublets_ is None:
    scrub.call_doublets(threshold=FALLBACK_DOUBLET_THRESHOLD)
predict_label = list(scrub.predicted_doublets_)

	
# Encode each doublet call as the string '1' (truthy) or '0' (falsy) for the
# output table.
predict_zero_one = ['1' if is_doublet else '0' for is_doublet in predict_label]

predict_table['predicted_doublets'] = predict_zero_one
# Expose the row labels (cell barcodes) as an explicit column so the table can
# be written without its index. Uses the public `.index` instead of the private
# pandas attribute `_stat_axis`.
predict_table['Barcode'] = predict_table.index.tolist()
predict_table = predict_table[['Barcode', 'doublet_scores', 'predicted_doublets']]
predict_table.to_csv(args.predicted_output_file, sep="\t", index=False)